import torch

from torch import nn

if __name__ == '__main__':
    # Demo: combine new-axis (``None``) indexing with slicing, then tile the
    # result with ``Tensor.repeat``, printing the shape after each step.
    base = torch.rand(2, 3, 3)

    # ``None`` prepends a singleton axis. The two slices then apply to the
    # original first two axes (2 -> 1 and 3 -> 2), and the trailing axis is
    # kept whole: (2, 3, 3) -> (1, 1, 2, 3).
    sliced = base[None, :1, :2]
    print(sliced.shape)

    # Repeat 3x along the leading axis and 1x along the remaining three:
    # (1, 1, 2, 3) -> (3, 1, 2, 3).
    tiled = sliced.repeat(3, 1, 1, 1)
    print(tiled.shape)
